//! True Random Number Generator (TRNG) driver.

use core::future::poll_fn;
use core::marker::PhantomData;
use core::ops::Not;
use core::task::Poll;

use embassy_hal_internal::{Peri, PeripheralType};
use embassy_sync::waitqueue::AtomicWaker;
use rand_core::Error;

use crate::interrupt::typelevel::{Binding, Interrupt};
use crate::peripherals::TRNG;
use crate::{interrupt, pac};

/// Module-private per-instance accessors.
///
/// Because this trait is not public, [`Instance`] cannot be implemented outside this module.
trait SealedInstance {
    /// Static waker shared by the async driver and the interrupt handler of this instance.
    fn waker() -> &'static AtomicWaker;
    /// Register block of this TRNG instance.
    fn regs() -> pac::trng::Trng;
}

/// TRNG peripheral instance.
// `SealedInstance` is intentionally private (sealed-trait pattern), which is why the
// `private_bounds` lint is silenced here.
#[allow(private_bounds)]
pub trait Instance: PeripheralType + SealedInstance {
    /// Interrupt for this peripheral.
    type Interrupt: Interrupt;
}

impl SealedInstance for TRNG {
    // Spelled via the `pac` alias, matching the trait declaration and the body below.
    fn regs() -> pac::trng::Trng {
        pac::TRNG
    }

    fn waker() -> &'static AtomicWaker {
        // Single waker for this instance: registered by the async fill and woken by the
        // interrupt handler.
        static WAKER: AtomicWaker = AtomicWaker::new();
        &WAKER
    }
}

impl Instance for TRNG {
    type Interrupt = interrupt::typelevel::TRNG_IRQ;
}

/// TRNG ROSC inverter chain length options.
///
/// The discriminant is the raw selector value programmed into the TRNG configuration.
///
/// NOTE(review): the `inverter_chain_length` config docs describe "four possible selections",
/// but this enum has five variants (0-4); confirm that `Four` fits the hardware selector field.
#[derive(Copy, Clone, Debug)]
pub enum InverterChainLength {
    /// Chain length setting 0.
    None = 0,
    /// Chain length setting 1.
    One = 1,
    /// Chain length setting 2.
    Two = 2,
    /// Chain length setting 3.
    Three = 3,
    /// Chain length setting 4.
    Four = 4,
}

impl From<InverterChainLength> for u8 {
    fn from(value: InverterChainLength) -> Self {
        value as u8
    }
}

/// Configuration for the TRNG.
///
/// - Three built-in entropy checks
/// - ROSC frequency controlled by selecting one of the ROSC chain lengths
/// - Sample period in terms of system clock ticks
///
/// The default configuration is based on the following guidance from the documentation:
///
/// > RP2350 Datasheet 12.12.2
/// >
/// > ...
/// >
/// > When configuring the TRNG block, consider the following principles:
/// > - As average generation time increases, result quality increases and failed entropy checks
/// >   decrease.
/// > - A low sample count decreases average generation time, but increases the chance of NIST
/// >   test-failing results and failed entropy checks.
/// >
/// > For acceptable results with an average generation time of about 2 milliseconds, use ROSC
/// > chain length settings of 0 or 1 and sample count settings of 20-25.
/// >
/// > Larger sample count settings (e.g. 100) provide proportionately slower average generation
/// > times. These settings significantly reduce, but do not eliminate NIST test failures and
/// > entropy check failures. Results occasionally take an especially long time to generate.
///
/// Note: the Pico SDK and bootrom don't use any of the entropy checks and sample the ROSC
/// directly by setting the sample period to 0. Random data collected this way is then passed
/// through either hardware-accelerated SHA-256 (bootrom) or `xoroshiro128**` (version 1.0!).
#[non_exhaustive]
#[derive(Copy, Clone, Debug)]
pub struct Config {
    /// Bypass the TRNG autocorrelation test.
    pub disable_autocorrelation_test: bool,
    /// Bypass the CRNGT test.
    pub disable_crngt_test: bool,
    /// When set, the von Neumann balancer is bypassed (including the
    /// 32 consecutive bits test).
    pub disable_von_neumann_balancer: bool,
    /// Sets the number of rng_clk cycles between two consecutive
    /// ring oscillator samples.
    ///
    /// Note: if the von Neumann balancer is bypassed, the sample count must be
    /// at least seventeen.
    pub sample_count: u32,
    /// Selects the number of inverters (out of four possible
    /// selections) in the ring oscillator (the entropy source). Higher values select
    /// longer inverter chain lengths.
    pub inverter_chain_length: InverterChainLength,
}

impl Default for Config {
    /// Chain length 1 and sample count 25, within the datasheet's "about 2 milliseconds"
    /// recommendation (chain length 0 or 1, sample count 20-25), with every entropy check enabled.
    fn default() -> Self {
        Self {
            sample_count: 25,
            inverter_chain_length: InverterChainLength::One,
            // WARNING: Disabling these tests increases likelihood of poor rng results.
            disable_autocorrelation_test: false,
            disable_crngt_test: false,
            disable_von_neumann_balancer: false,
        }
    }
}

/// True Random Number Generator driver for the RP2350.
///
/// This driver provides async and blocking options.
///
/// See [`Config`] for configuration details.
///
/// Usage example:
/// ```no_run
/// use embassy_executor::Spawner;
/// use embassy_rp::trng::Trng;
/// use embassy_rp::peripherals::TRNG;
/// use embassy_rp::bind_interrupts;
///
/// bind_interrupts!(struct Irqs {
///     TRNG_IRQ => embassy_rp::trng::InterruptHandler<TRNG>;
/// });
///
/// #[embassy_executor::main]
/// async fn main(_spawner: Spawner) {
///     let peripherals = embassy_rp::init(Default::default());
///     let mut trng = Trng::new(peripherals.TRNG, Irqs, embassy_rp::trng::Config::default());
///
///     let mut randomness = [0u8; 58];
///     loop {
///         trng.fill_bytes(&mut randomness).await;
///         assert_ne!(randomness, [0u8; 58]);
///     }
/// }
/// ```
pub struct Trng<'d, T: Instance> {
    phantom: PhantomData<&'d mut T>,
    /// Configuration applied every time the TRNG is (re)initialized.
    config: Config,
}

/// Number of entropy bits the TRNG produces per request.
///
/// RP2350 datasheet 12.12.1 (Overview): "On request, the TRNG block generates a block of 192
/// entropy bits generated by automatically processing a series of periodic samples from the
/// TRNG block’s internal Ring Oscillator (ROSC)."
const TRNG_BLOCK_SIZE_BITS: usize = 192;
/// The same block size in bytes; one block fills the six 32-bit EHR result registers.
const TRNG_BLOCK_SIZE_BYTES: usize = TRNG_BLOCK_SIZE_BITS / u8::BITS as usize;

impl<'d, T: Instance> Trng<'d, T> {
    /// Create a new TRNG driver.
    ///
    /// The peripheral is configured from `config` immediately. The entropy source itself is
    /// enabled by the fill/next methods below and disabled again when they complete.
    pub fn new(_trng: Peri<'d, T>, _irq: impl Binding<T::Interrupt, InterruptHandler<T>> + 'd, config: Config) -> Self {
        let trng = Trng {
            phantom: PhantomData,
            config,
        };
        trng.initialize_rng();
        trng
    }

    /// Enable the TRNG ring oscillator (the entropy source).
    fn start_rng(&self) {
        T::regs().rnd_source_enable().write(|w| w.set_rnd_src_en(true));
    }

    /// Disable the ring oscillator and reset the TRNG bits counter.
    fn stop_rng(&self) {
        let regs = T::regs();
        regs.rnd_source_enable().write(|w| w.set_rnd_src_en(false));
        regs.rst_bits_counter().write(|w| w.set_rst_bits_counter(true));
    }

    /// Program the interrupt mask, ROSC chain length, sample count and test-bypass bits
    /// from `self.config`.
    fn initialize_rng(&self) {
        let regs = T::regs();

        // Unmask the EHR-valid interrupt so the async path is woken when a block is ready.
        regs.rng_imr().write(|w| w.set_ehr_valid_int_mask(false));

        // `InverterChainLength` is `Copy`, so no clone is needed before converting.
        regs.trng_config().write(|w| {
            w.set_rnd_src_sel(self.config.inverter_chain_length.into());
        });

        regs.sample_cnt1().write(|w| {
            *w = self.config.sample_count;
        });

        regs.trng_debug_control().write(|w| {
            w.set_auto_correlate_bypass(self.config.disable_autocorrelation_test);
            w.set_trng_crngt_bypass(self.config.disable_crngt_test);
            w.set_vnc_bypass(self.config.disable_von_neumann_balancer);
        });
    }

    fn enable_irq(&self) {
        // SAFETY: `new` requires a `Binding` of this interrupt to `InterruptHandler<T>`, so a
        // handler is installed before the interrupt can be unmasked here.
        unsafe { T::Interrupt::enable() }
    }

    fn disable_irq(&self) {
        T::Interrupt::disable();
    }

    /// Spin until the TRNG holds a valid block of entropy.
    ///
    /// If generation stopped because of an autocorrelation failure, the TRNG is soft-reset,
    /// reconfigured and restarted, and waiting continues.
    ///
    /// # Panics
    ///
    /// Panics if the TRNG is idle with no valid result and no autocorrelation error flagged.
    fn blocking_wait_for_successful_generation(&self) {
        let regs = T::regs();

        loop {
            while regs.trng_busy().read().trng_busy() {}

            if regs.trng_valid().read().ehr_valid() {
                return;
            }

            if regs.rng_isr().read().autocorr_err().not() {
                panic!("RNG not busy, but ehr is not valid!")
            }

            regs.trng_sw_reset().write(|w| w.set_trng_sw_reset(true));
            // Fixed delay is required after TRNG soft reset. This read is sufficient.
            regs.trng_sw_reset().read();
            self.initialize_rng();
            self.start_rng();
        }
    }

    /// Copy the six 32-bit EHR result registers into `buffer`, in register order.
    ///
    /// `EHR_DATA5` is read last: per datasheet 12.12.3, reading it clears all result registers.
    fn read_ehr_registers_into_array(&self, buffer: &mut [u8; TRNG_BLOCK_SIZE_BYTES]) {
        let regs = T::regs();
        let ehr_data_regs = [
            regs.ehr_data0(),
            regs.ehr_data1(),
            regs.ehr_data2(),
            regs.ehr_data3(),
            regs.ehr_data4(),
            regs.ehr_data5(),
        ];

        for (bytes, reg) in buffer.chunks_exact_mut(4).zip(ehr_data_regs.iter()) {
            bytes.copy_from_slice(&reg.read().to_ne_bytes());
        }
    }

    /// Wait for a valid block, then copy it into `buffer`.
    fn blocking_read_ehr_registers_into_array(&self, buffer: &mut [u8; TRNG_BLOCK_SIZE_BYTES]) {
        self.blocking_wait_for_successful_generation();
        self.read_ehr_registers_into_array(buffer);
    }

    /// Fill the buffer with random bytes, async version.
    ///
    /// The TRNG is started and its interrupt enabled for the duration of the fill. Both are
    /// turned off again once `destination` is full.
    ///
    /// NOTE: if the returned future is dropped before it completes, the TRNG is left running
    /// and its interrupt left enabled.
    pub async fn fill_bytes(&mut self, destination: &mut [u8]) {
        if destination.is_empty() {
            return; // Nothing to fill
        }

        self.start_rng();
        self.enable_irq();

        let regs = T::regs();
        let waker = T::waker();
        let mut buffer = [0u8; TRNG_BLOCK_SIZE_BYTES];
        let mut bytes_transferred = 0usize;

        poll_fn(|context| {
            waker.register(context.waker());

            if regs.trng_busy().read().trng_busy() {
                return Poll::Pending;
            }

            // If woken up and EHR is *not* valid, assume the TRNG has been reset
            // (see the interrupt handler) and reinitialize, restart.
            if regs.trng_valid().read().ehr_valid().not() {
                self.initialize_rng();
                self.start_rng();
                return Poll::Pending;
            }

            self.read_ehr_registers_into_array(&mut buffer);

            // Copy a full block, or just what is left for the final chunk.
            let chunk_len = (destination.len() - bytes_transferred).min(TRNG_BLOCK_SIZE_BYTES);
            destination[bytes_transferred..bytes_transferred + chunk_len].copy_from_slice(&buffer[..chunk_len]);
            bytes_transferred += chunk_len;

            if bytes_transferred == destination.len() {
                self.stop_rng();
                self.disable_irq();
                Poll::Ready(())
            } else {
                // Reading EHR_DATA5 cleared the result; the next block's interrupt wakes us.
                Poll::Pending
            }
        })
        .await
    }

    /// Fill the buffer with random bytes, blocking version.
    ///
    /// # Panics
    ///
    /// Panics if the TRNG stops without a valid result for a reason other than an
    /// autocorrelation failure.
    pub fn blocking_fill_bytes(&mut self, destination: &mut [u8]) {
        if destination.is_empty() {
            return; // Nothing to fill
        }
        self.start_rng();

        let mut buffer = [0u8; TRNG_BLOCK_SIZE_BYTES];

        for chunk in destination.chunks_mut(TRNG_BLOCK_SIZE_BYTES) {
            // Waits for generation itself, so no separate wait is needed here.
            self.blocking_read_ehr_registers_into_array(&mut buffer);
            chunk.copy_from_slice(&buffer[..chunk.len()]);
        }
        self.stop_rng();
    }

    /// Return a random u32, blocking.
    pub fn blocking_next_u32(&mut self) -> u32 {
        let regs = T::regs();
        self.start_rng();
        self.blocking_wait_for_successful_generation();
        // 12.12.3 After successful generation, read the last result register, EHR_DATA[5] to
        // clear all of the result registers.
        let result = regs.ehr_data5().read();
        self.stop_rng();
        result
    }

    /// Return a random u64, blocking.
    ///
    /// The low word comes from `EHR_DATA4` and the high word from `EHR_DATA5`.
    pub fn blocking_next_u64(&mut self) -> u64 {
        let regs = T::regs();
        self.start_rng();
        self.blocking_wait_for_successful_generation();

        let low = u64::from(regs.ehr_data4().read());
        // 12.12.3 After successful generation, read the last result register, EHR_DATA[5] to
        // clear all of the result registers.
        let high = u64::from(regs.ehr_data5().read());
        self.stop_rng();
        (high << 32) | low
    }
}

/// Blocking [`rand_core::RngCore`] implementation.
///
/// Every method delegates to the matching `blocking_*` method of [`Trng`].
impl<'d, T: Instance> rand_core::RngCore for Trng<'d, T> {
    fn next_u32(&mut self) -> u32 {
        Self::blocking_next_u32(self)
    }

    fn next_u64(&mut self) -> u64 {
        Self::blocking_next_u64(self)
    }

    fn fill_bytes(&mut self, dest: &mut [u8]) {
        Self::blocking_fill_bytes(self, dest)
    }

    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
        // `blocking_fill_bytes` has no error return, so this always reports success.
        Self::blocking_fill_bytes(self, dest);
        Ok(())
    }
}

/// Marks [`Trng`] as a `rand_core::CryptoRng`.
impl<T: Instance> rand_core::CryptoRng for Trng<'_, T> {}

/// TRNG interrupt handler.
///
/// Bind it to `TRNG_IRQ` with `bind_interrupts!` and pass the binding to `Trng::new`.
pub struct InterruptHandler<T: Instance> {
    _instance: PhantomData<T>,
}

impl<T: Instance> interrupt::typelevel::Handler<T::Interrupt> for InterruptHandler<T> {
    unsafe fn on_interrupt() {
        let regs = T::regs();
        let isr = regs.rng_isr().read();
        if isr.ehr_valid() {
            regs.rng_icr().write(|w| {
                w.set_ehr_valid(true);
            });
            T::waker().wake();
        } else if isr.crngt_err() {
            warn!("TRNG CRNGT error! Increase sample count to reduce likelihood");
            regs.rng_icr().write(|w| {
                w.set_crngt_err(true);
            });
        } else if isr.vn_err() {
            warn!("TRNG Von-Neumann balancer error! Increase sample count to reduce likelihood");
            regs.rng_icr().write(|w| {
                w.set_vn_err(true);
            });
        } else if isr.autocorr_err() {
            // 12.12.5. List of Registers
            // ...
            // TRNG: RNG_ISR Register
            // ...
            // AUTOCORR_ERR: 1 indicates Autocorrelation test failed four times in a row.
            // When set, RNG ceases functioning until next reset
            warn!("TRNG Autocorrect error! Resetting TRNG. Increase sample count to reduce likelihood");
            regs.trng_sw_reset().write(|w| {
                w.set_trng_sw_reset(true);
            });
            // Fixed delay is required after TRNG soft reset, this read is sufficient.
            regs.trng_sw_reset().read();
            // Wake up to reinitialize and restart the TRNG.
            T::waker().wake();
        }
    }
}
